{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "# Add ../learner/ (relative to the working directory) to the module search path.\n",
    "# NOTE(review): presumably this is where the tf_learner module imported below lives - confirm.\n",
    "sys.path.append(\"../learner/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Train MNIST with LeNet.\"\"\"\n",
    "import logging\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import tensorflow.keras.layers as layer\n",
    "from tensorflow.keras import Model\n",
    "import matplotlib.pyplot as plt\n",
    "from functools import partial\n",
    "from abc import ABC\n",
    "\n",
    "from tf_learner import Learner, set_seed, Metrics, LrScheduler, Callback"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# center loss\n",
    "\n",
    "class CenterLoss(object):\n",
    "    \"\"\"Center loss with one learnable center per class.\n",
    "\n",
    "    Each sample contributes ||feature - center[label]||**2 / 2, divided by\n",
    "    (1 + number of samples carrying the same label in the batch); the\n",
    "    contributions are summed over the batch.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, cls_num, feature_dim):\n",
    "        \"\"\"Create the center matrix.\n",
    "\n",
    "        Args:\n",
    "            cls_num: number of classes (rows of the center matrix).\n",
    "            feature_dim: size of each center vector (columns of the center matrix).\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.cls_num = cls_num\n",
    "        # Misspelled attribute kept so existing readers of `featur_num` keep working.\n",
    "        self.featur_num = feature_dim\n",
    "        self.feature_num = feature_dim\n",
    "        # Centers are initialised from a normal distribution with mean 0 and stddev 1.\n",
    "        # NOTE(review): this variable is created outside the Keras model - confirm that the\n",
    "        # Learner hands it to the optimizer, otherwise the centers never move.\n",
    "        self.center = tf.Variable(tf.random.normal((cls_num, feature_dim), 0.0, 1.0))\n",
    "        # Not read anywhere in this class; kept for compatibility.\n",
    "        self.is_first = True\n",
    "\n",
    "    def __call__(self, ys, xs):\n",
    "        \"\"\"Return the scalar center loss of one batch.\n",
    "\n",
    "        Args:\n",
    "            ys: class labels; cast to int32 and used as row indices into the centers.\n",
    "            xs: feature embeddings, one row per label, same width as the centers.\n",
    "        \"\"\"\n",
    "        labels = tf.cast(ys, tf.int32)\n",
    "        center_exp = tf.gather(self.center, labels)\n",
    "        # Samples per class in this batch. An integer bincount states the intent directly\n",
    "        # and does not depend on how a float histogram places integer labels into bins.\n",
    "        count = tf.math.bincount(labels, minlength=self.cls_num, maxlength=self.cls_num)\n",
    "        # +1 keeps the divisor away from zero and damps classes that dominate the batch.\n",
    "        count_dis = tf.gather(count, labels) + 1\n",
    "        sq_dist = tf.reduce_sum((xs - center_exp) ** 2, axis=1)\n",
    "        loss = tf.reduce_sum(sq_dist / 2.0 / tf.cast(count_dis, tf.float32))\n",
    "        return loss\n",
    "\n",
    "\n",
    "class MyLoss(object):\n",
    "    \"\"\"Sparse categorical cross-entropy plus center loss.\n",
    "\n",
    "    `outputs` must be indexable: outputs[0] is fed to the center loss and\n",
    "    outputs[1] to the cross-entropy term.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, cls_num: int = 10, feature_dim: int = 2):\n",
    "        super().__init__()\n",
    "        self.loss = tf.losses.SparseCategoricalCrossentropy()\n",
    "        self.center_loss = CenterLoss(cls_num, feature_dim)\n",
    "\n",
    "    def __call__(self, targets, outputs):\n",
    "        \"\"\"Return cross-entropy(targets, outputs[1]) + center_loss(targets, outputs[0]).\"\"\"\n",
    "        ce_term = self.loss(targets, outputs[1])\n",
    "        center_term = self.center_loss(targets, outputs[0])\n",
    "        return ce_term + center_term"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# visualize callback\n",
    "class TrainCallback(Callback):\n",
    "    \"\"\"Training callback that collects per-batch embeddings and plots them after every epoch.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.embeddings = []    # first model output of every batch (plotted as 2-D points)\n",
    "        self.labels = []        # matching labels of every batch, cast to int32\n",
    "\n",
    "    def on_epoch_begin(self, info: dict):\n",
    "        \"\"\"Discard everything collected during the previous epoch.\"\"\"\n",
    "        self.embeddings.clear()\n",
    "        self.labels.clear()\n",
    "\n",
    "    def on_loss_begin(self, info: dict):\n",
    "        \"\"\"Store the current batch's first output and its labels.\"\"\"\n",
    "        self.embeddings.append(info[\"outputs\"][0])\n",
    "        # NOTE(review): assumes info['y'] is a sequence whose first item is the label batch - confirm.\n",
    "        self.labels.append(tf.cast(info[\"y\"][0], tf.int32))\n",
    "\n",
    "    def on_metric_begin(self, info: dict):\n",
    "        \"\"\"Keep the full outputs under 'orig_outputs' and expose only the second output to the metrics.\"\"\"\n",
    "        info[\"orig_outputs\"] = info[\"outputs\"]\n",
    "        info[\"outputs\"] = info[\"outputs\"][1]\n",
    "\n",
    "    def on_epoch_end(self, info: dict):\n",
    "        \"\"\"Concatenate the collected batches and plot them for the finished epoch.\"\"\"\n",
    "        embeddings = tf.concat(self.embeddings, axis=0)\n",
    "        labels = tf.concat(self.labels, axis=0)\n",
    "        self.visualize(embeddings.numpy(), labels.numpy(), info[\"epoch\"])\n",
    "\n",
    "    @staticmethod\n",
    "    def visualize(feat, labels, epoch):\n",
    "        \"\"\"Scatter-plot the first two feature columns coloured by digit and save the figure.\n",
    "\n",
    "        Args:\n",
    "            feat: array of features; columns 0 and 1 are used as x and y.\n",
    "            labels: array of class ids, compared against 0..9.\n",
    "            epoch: epoch number shown on the plot and used in the output file name.\n",
    "        \"\"\"\n",
    "        plt.ion()\n",
    "        colors = ['#ff0000', '#ffff00', '#00ff00', '#00ffff', '#0000ff',\n",
    "                  '#ff00ff', '#990000', '#999900', '#009900', '#009999']\n",
    "        plt.clf()\n",
    "        for i in range(10):\n",
    "            mask = labels == i\n",
    "            plt.plot(feat[mask, 0], feat[mask, 1], '.', c=colors[i])\n",
    "        plt.legend([str(i) for i in range(10)], loc='upper right')\n",
    "        plt.xlim(xmin=-2.5, xmax=2.5)\n",
    "        plt.ylim(ymin=-2.5, ymax=2.5)\n",
    "        # Anchor the label in axes coordinates (top-left corner) so it stays inside the\n",
    "        # plot whatever the axis limits are; data coordinates (-7.8, 7.3) fall outside\n",
    "        # the +/-2.5 window set above.\n",
    "        plt.text(0.02, 0.95, \"epoch=%d\" % epoch, transform=plt.gca().transAxes)\n",
    "        plt.savefig('/tmp/epoch=%d.jpg' % epoch)\n",
    "        plt.draw()\n",
    "        plt.pause(0.001)\n",
    "\n",
    "\n",
    "class ValidCallback(Callback):\n",
    "    \"\"\"Validation callback that narrows the model outputs before metrics are computed.\"\"\"\n",
    "\n",
    "    def on_metric_begin(self, info: dict):\n",
    "        \"\"\"Stash the full outputs under 'orig_outputs' and keep only the second one for metrics.\"\"\"\n",
    "        # The model returns two values, but metrics are computed from the second one only.\n",
    "        full_outputs = info[\"outputs\"]\n",
    "        info[\"orig_outputs\"] = full_outputs\n",
    "        info[\"outputs\"] = full_outputs[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model \n",
    "class MyModel(Model, ABC):\n",
    "    \"\"\"LeNet-style CNN with a 2-unit bottleneck in front of the 10-way classifier.\n",
    "\n",
    "    `call` returns (embedding, probabilities): the PReLU-activated output of the\n",
    "    2-unit dense layer and the softmax of the final 10-unit dense layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(MyModel, self).__init__()\n",
    "        self.conv1 = layer.Conv2D(filters=6, kernel_size=5)\n",
    "        self.relu1 = layer.PReLU(alpha_initializer=tf.constant_initializer(0.25))\n",
    "        self.pool1 = layer.MaxPool2D(2)\n",
    "        self.conv2 = layer.Conv2D(filters=16, kernel_size=5)\n",
    "        self.relu2 = layer.PReLU(alpha_initializer=tf.constant_initializer(0.25))\n",
    "        self.pool2 = layer.MaxPool2D(2)\n",
    "        # Weight-free layer that collapses every non-batch axis. Unlike reshaping with the\n",
    "        # static y.shape[0], it also works when the batch size is unknown (None).\n",
    "        self.flatten = layer.Flatten()\n",
    "        self.fc1 = layer.Dense(120)\n",
    "        self.relu3 = layer.PReLU(alpha_initializer=tf.constant_initializer(0.25))\n",
    "        self.fc2 = layer.Dense(84)\n",
    "        self.relu4 = layer.PReLU(alpha_initializer=tf.constant_initializer(0.25))\n",
    "        self.fc3 = layer.Dense(2)\n",
    "        self.relu5 = layer.PReLU(alpha_initializer=tf.constant_initializer(0.25))\n",
    "        self.fc = layer.Dense(10)\n",
    "\n",
    "    @tf.function\n",
    "    def call(self, x, training=None, mask=None):\n",
    "        \"\"\"Run the forward pass.\n",
    "\n",
    "        Args:\n",
    "            x: input batch fed to the first convolution.\n",
    "            training: accepted for Keras compatibility; not used here.\n",
    "            mask: accepted for Keras compatibility; not used here.\n",
    "\n",
    "        Returns:\n",
    "            Tuple (embedding, probabilities) as described on the class.\n",
    "        \"\"\"\n",
    "        y = self.conv1(x)\n",
    "        y = self.relu1(y)\n",
    "        y = self.pool1(y)\n",
    "        y = self.conv2(y)\n",
    "        y = self.relu2(y)\n",
    "        y = self.pool2(y)\n",
    "        y = self.flatten(y)\n",
    "        y = self.fc1(y)\n",
    "        y = self.relu3(y)\n",
    "        y = self.fc2(y)\n",
    "        y = self.relu4(y)\n",
    "        y = self.fc3(y)\n",
    "        embedding = self.relu5(y)\n",
    "        logits = self.fc(embedding)\n",
    "        return embedding, tf.nn.softmax(logits, axis=1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset\n",
    "# 1. Load MNIST, divide the pixel values by 255, append a trailing channel\n",
    "#    axis and store the images as float32.\n",
    "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
    "x_train = (x_train / 255.0)[..., tf.newaxis].astype(np.float32)\n",
    "x_test = (x_test / 255.0)[..., tf.newaxis].astype(np.float32)\n",
    "# 2. Wrap the (image, label) arrays as tf.data datasets.\n",
    "train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
    "valid_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train\n",
    "model = MyModel()\n",
    "\n",
    "# Metric callables handed to the Learner, keyed by name.\n",
    "metric_funcs = {\n",
    "    'accuracy': partial(Metrics.accuracy_score, just_score=False),\n",
    "    'precision_macro': partial(Metrics.precision_score, average='macro', just_score=False),\n",
    "    'recall_macro': partial(Metrics.recall_score, average='macro', just_score=False),\n",
    "    'f1_macro': partial(Metrics.f1_score, average='macro', just_score=False),\n",
    "}\n",
    "\n",
    "learner = Learner(\n",
    "    model=model,\n",
    "    train_ds=train_dataset,\n",
    "    valid_ds=valid_dataset,\n",
    "    metrics=metric_funcs,\n",
    "    loss_func=MyLoss(),\n",
    "    optim_func=tf.optimizers.Adam,\n",
    "    batch_size=128,\n",
    "    lr=0.01,\n",
    "    lr_scheduler=LrScheduler.get_step_lr(step_size=5, gamma=0.2),\n",
    "    valid_batch=-1,\n",
    "    device=tf.device(\"GPU\"),\n",
    "    train_callbacks=[TrainCallback()],\n",
    "    valid_callbacks=[ValidCallback()],\n",
    ")\n",
    "# Run training for 20 epochs.\n",
    "learner.train(20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
